"""
# Copyright (c) 2025  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""

import argparse
import asyncio
import contextlib
import os
import signal
import socket
import subprocess
import time
from typing import Union

import openai
import yaml
from benchmark_dataset import EBChatDataset, EBDataset, SampleRequest
from benchmark_serving import benchmark


def prepare_input_requests(
    num_prompts: int, dataset_name: str, dataset_path: str
) -> "list[SampleRequest]":
    """Load the named dataset and sample ``num_prompts`` requests from it.

    Args:
        num_prompts: Number of requests to sample.
        dataset_name: Which loader to use: ``"EB"`` or ``"EBChat"``.
        dataset_path: Path handed to the dataset loader's constructor.

    Returns:
        The result of the dataset's ``sample(num_requests=...)`` call
        (presumably a list of ``SampleRequest``; TODO confirm against
        ``benchmark_dataset``).

    Raises:
        ValueError: If ``dataset_name`` is not one of the known datasets.
    """
    # Factories keep dataset construction lazy: only the selected loader runs.
    dataset_factories = {
        "EB": lambda: EBDataset(dataset_path=dataset_path),
        "EBChat": lambda: EBChatDataset(dataset_path=dataset_path),
    }

    try:
        make_dataset = dataset_factories[dataset_name]
    except KeyError as err:
        raise ValueError(f"Unknown dataset: {dataset_name}") from err

    # Load/sample outside the try: a KeyError raised while reading the data
    # must surface as-is rather than be reported as an unknown dataset name.
    return make_dataset().sample(num_requests=num_prompts)


class FakeTokenizer:
    """Placeholder tokenizer whose ``encode`` never produces any token ids."""

    def encode(self, text: str, add_special_tokens: bool = False):
        """Disregard both arguments and hand back a new, empty list."""
        token_ids = list()
        return token_ids


def send_one_batch(base_url, max_concurrency, input_requests, disable_tqdm):
    """Run a single ``benchmark`` pass against the chat-completions endpoint.

    Requests are issued with an infinite request rate; ``max_concurrency`` is
    forwarded to ``benchmark`` (presumably capping in-flight requests). Only
    the ``s_itl`` metric is requested, with no percentiles.

    Returns:
        A dict holding only ``"mean_s_itl_ms"`` copied from the benchmark
        results.
    """
    chat_completions_url = f"{base_url}/v1/chat/completions"
    benchmark_coro = benchmark(
        backend="openai-chat",
        api_url=chat_completions_url,
        base_url=base_url,
        model_id="default",
        model_name="default",
        input_requests=input_requests,
        hyper_parameters={},
        logprobs=None,
        request_rate=float("inf"),
        burstiness=1.0,
        disable_tqdm=disable_tqdm,
        profile=False,
        selected_percentile_metrics=["s_itl"],
        selected_percentiles=[],
        ignore_eos=False,
        goodput_config_dict=None,
        max_concurrency=max_concurrency,
        lora_modules=None,
        extra_body=None,
    )
    outcome = asyncio.run(benchmark_coro)
    return {"mean_s_itl_ms": outcome["mean_s_itl_ms"]}


def calculate_speedup(acceptance_rate, draft_token_step, t_ori, t_mtp):
    """Estimate how much faster MTP decoding is than the base model.

    With ``accepted = acceptance_rate**1 + ... + acceptance_rate**draft_token_step``
    the result is ``t_ori / ((1 - accepted / (1 + accepted)) * t_mtp)``,
    which is algebraically ``(1 + accepted) * t_ori / t_mtp``. ``main`` passes
    the base-model s_itl as ``t_ori`` and the measured MTP s_itl as ``t_mtp``.
    A non-positive ``draft_token_step`` yields ``t_ori / t_mtp``.
    """
    accepted = 0.0
    for exponent in range(1, draft_token_step + 1):
        accepted += acceptance_rate ** exponent

    accept_ratio = accepted / (1 + accepted)
    remaining_share = 1 - accept_ratio
    return t_ori / (remaining_share * t_mtp)


def main(args):
    """Benchmark the server at each concurrency level and print MTP speedups.

    For every ``(max_concurrency, s_itl)`` pair from ``args.max_concurrency``
    and ``args.s_itl_base_model``:
      1. a warmup batch of ``max_concurrency`` requests is sent with its stdout
         suppressed;
      2. the full request set is benchmarked;
      3. ``calculate_speedup`` is printed for each ``args.draft_token_steps``
         value, using ``s_itl`` as the base-model latency.

    Raises:
        ValueError: If ``--s_itl-base-model`` was not given, or its length
            differs from ``--max-concurrency``.
    """
    # Validate CLI input before loading the dataset so mistakes fail fast.
    if args.s_itl_base_model is None:
        raise ValueError(
            "--s_itl-base-model is required (one value per --max-concurrency entry)"
        )
    if len(args.max_concurrency) != len(args.s_itl_base_model):
        raise ValueError(
            "--max-concurrency should be same length as --s_itl-base-model"
        )

    base_url = f"http://{args.host}:{args.port}"
    input_requests = prepare_input_requests(
        args.num_prompts, args.dataset_name, args.dataset_path
    )

    for max_concurrency, s_itl in zip(args.max_concurrency, args.s_itl_base_model):
        # Warmup: one batch sized to the concurrency level, output discarded.
        print("Starting warmup...")
        with open(os.devnull, "w") as devnull, contextlib.redirect_stdout(devnull):
            send_one_batch(
                base_url, max_concurrency, input_requests[:max_concurrency], True
            )

        # Benchmark
        record = send_one_batch(base_url, max_concurrency, input_requests, False)

        # Include the concurrency level so results from each loop iteration
        # can be told apart in the output.
        metric_header = f"Speed up (max concurrency {max_concurrency})"
        print("{s:{c}^{n}}".format(s=metric_header, n=50, c="-"))
        for draft_token_step in args.draft_token_steps:
            speedup = calculate_speedup(
                args.acceptance_rate,
                draft_token_step,
                s_itl,
                record["mean_s_itl_ms"],
            )
            print(
                "{:<40} {:<10.2f}".format(
                    f"Speed up on {draft_token_step} steps draft", speedup
                )
            )
        print("=" * 50)


def build_arg_parser():
    """Build the command-line parser whose namespace is consumed by ``main``.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser(
        description=(
            "Benchmark an MTP-enabled server and estimate its speedup over "
            "the base model."
        )
    )
    parser.add_argument(
        "--host",
        type=str,
        default="127.0.0.1",
        help="Server host.",
    )
    parser.add_argument(
        "--port",
        type=str,
        default="8000",
        help="Server port.",
    )
    parser.add_argument(
        "--max-concurrency",
        type=int,
        nargs="+",
        default=(1, 2, 4, 8, 16, 32),
        help="Concurrency levels to benchmark; one run per value.",
    )
    parser.add_argument(
        "--num-prompts",
        type=int,
        default=128,
        help="Number of requests sampled from the dataset.",
    )
    parser.add_argument(
        "--acceptance-rate",
        type=float,
        default=0.8,
        help="Assumed per-token draft acceptance rate used in the estimate.",
    )
    parser.add_argument(
        "--draft-token-steps",
        type=int,
        nargs="+",
        default=(1, 2),
        help="Draft token step counts to report a speedup for.",
    )
    # main() cannot run without this, so it is required. The hyphenated alias
    # matches the other flags; explicit dest keeps args.s_itl_base_model.
    parser.add_argument(
        "--s_itl-base-model",
        "--s-itl-base-model",
        dest="s_itl_base_model",
        type=float,
        nargs="+",
        required=True,
        help=(
            "Base-model s_itl for each --max-concurrency value, in the same "
            "order and unit as the measured mean_s_itl_ms."
        ),
    )
    parser.add_argument(
        "--dataset-name",
        type=str,
        default="EBChat",
        help='Dataset loader to use: "EB" or "EBChat".',
    )
    parser.add_argument(
        "--dataset-path",
        type=str,
        help="Dataset path passed to the dataset loader.",
    )
    return parser


if __name__ == "__main__":
    main(build_arg_parser().parse_args())
